
Minor Touches for ScoreGradELBO #99

Merged: 29 commits merged into master from tidy_scoregradelbo on Dec 6, 2024

Conversation

Red-Portal (Member)

No description provided.


@github-actions github-actions bot left a comment


Benchmark Results

Benchmark suite Current: 801a76b Previous: b6083ed Ratio
normal/RepGradELBO + STL/meanfield/Zygote 15540028767 ns 15262008164 ns 1.02
normal/RepGradELBO + STL/meanfield/ForwardDiff 3336745945 ns 3126611341 ns 1.07
normal/RepGradELBO + STL/meanfield/ReverseDiff 3252943822 ns 3207060677 ns 1.01
normal/RepGradELBO + STL/fullrank/Zygote 15109365150 ns 15220603739 ns 0.99
normal/RepGradELBO + STL/fullrank/ForwardDiff 3677404700 ns 3439213244 ns 1.07
normal/RepGradELBO + STL/fullrank/ReverseDiff 5745669202 ns 5765830869 ns 1.00
normal/RepGradELBO/meanfield/Zygote 7223742806 ns 7257011712 ns 1.00
normal/RepGradELBO/meanfield/ForwardDiff 2387860875 ns 2294402627.5 ns 1.04
normal/RepGradELBO/meanfield/ReverseDiff 1474387789 ns 1435638209 ns 1.03
normal/RepGradELBO/fullrank/Zygote 7231934077 ns 7176706497 ns 1.01
normal/RepGradELBO/fullrank/ForwardDiff 2566869136 ns 2493123842.5 ns 1.03
normal/RepGradELBO/fullrank/ReverseDiff 2581003358 ns 2529829241 ns 1.02
normal + bijector/RepGradELBO + STL/meanfield/Zygote 23549114509 ns 23064269596 ns 1.02
normal + bijector/RepGradELBO + STL/meanfield/ForwardDiff 10271662690 ns 9877706080 ns 1.04
normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff 5166012174 ns 4966628075 ns 1.04
normal + bijector/RepGradELBO + STL/fullrank/Zygote 23512263396 ns 23022827020 ns 1.02
normal + bijector/RepGradELBO + STL/fullrank/ForwardDiff 10840780388 ns 10583173397 ns 1.02
normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff 8241021709 ns 8173149369 ns 1.01
normal + bijector/RepGradELBO/meanfield/Zygote 14938818567 ns 14449178740 ns 1.03
normal + bijector/RepGradELBO/meanfield/ForwardDiff 9351470129 ns 8845141905 ns 1.06
normal + bijector/RepGradELBO/meanfield/ReverseDiff 3197315608 ns 3004891758 ns 1.06
normal + bijector/RepGradELBO/fullrank/Zygote 14961186110 ns 14597576336 ns 1.02
normal + bijector/RepGradELBO/fullrank/ForwardDiff 9897958470 ns 9482666704 ns 1.04
normal + bijector/RepGradELBO/fullrank/ReverseDiff 4617350745 ns 4449258051 ns 1.04

This comment was automatically generated by workflow using github-action-benchmark.


codecov bot commented Oct 1, 2024

Codecov Report

Attention: Patch coverage is 95.00000% with 1 line in your changes missing coverage. Please review.

Project coverage is 92.65%. Comparing base (616b581) to head (5ff79e3).
Report is 5 commits behind head on master.

Files with missing lines Patch % Lines
src/objectives/elbo/scoregradelbo.jl 94.11% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master      #99      +/-   ##
==========================================
- Coverage   93.54%   92.65%   -0.90%     
==========================================
  Files          12       12              
  Lines         372      354      -18     
==========================================
- Hits          348      328      -20     
- Misses         24       26       +2     

☔ View full report in Codecov by Sentry.

@Red-Portal (Member, Author)

@willtebbutt Hi, it seems only Mooncake is failing. Could you look into the error messages? If they don't quite make sense, I'll try to put together a MWE.

@willtebbutt (Member)

Aha! I've been waiting for an example where this happens -- I've been aware of this PartialTypeVar thing for a while, but haven't had a chance to dig into it properly. I'll figure out what's going on :)

@willtebbutt (Member)

This being said, if you're able to make a MWE that I can run easily on my machine, that would be great. Just whatever the function is that you're differentiating, because I'm not exactly sure how your tests are structured, and therefore where to find the correct function.

Red-Portal and others added 2 commits October 3, 2024 20:34
Sampling is now done out of the AD path.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@Red-Portal (Member, Author)

@arnauqb Hi, I shifted things around a bit. Could you take a look and see if you have any comments? Also, the default value for the baseline, which you set to 1, is somewhat curious. Shouldn't this be 0 for the control variate to have no effect? Or maybe I'm misunderstanding something.

@arnauqb (Contributor)

arnauqb commented Oct 4, 2024

Hi @Red-Portal, thanks for tidying this up, looks good to me. Yes, you are right that the baseline should be 0 to have no effect, not sure why I set it to 1 😅
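
For reference, a minimal sketch of how a constant baseline typically enters a score-gradient (REINFORCE) surrogate; the names below are illustrative and are not the actual AdvancedVI internals:

using Distributions, Statistics

# Surrogate whose gradient with respect to the parameters of q is the score estimator.
# The samples zs are drawn from q beforehand and passed in as constants, so only
# logpdf(q, z) carries a gradient with respect to q's parameters.
function score_surrogate(logπ, q, zs, baseline)
    return mean((logπ(z) - baseline) * logpdf(q, z) for z in eachcol(zs))
end

# With baseline = 0 the control variate vanishes and this reduces to the plain score
# estimator; any other constant baseline changes only the variance, not the expectation.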

@Red-Portal Red-Portal added this to the v0.3.0 milestone Oct 4, 2024
@arnauqb (Contributor)

arnauqb commented Oct 9, 2024

@Red-Portal After a bit of testing, I ran into an error when sampling from two priors like this:

@model function make_model(data)
    p1 ~ Dirichlet(3, 1.0)
    p2 ~ Dirichlet(2, 1.0)
    data ~ CustomDistribution(vcat(p1, p2))
end

and then transforming q using Bijectors.

When using the score estimator I get an input mismatch (3 != 5) error when it tries to differentiate through the scoregradelbo objective. I believe something is happening inside the Bijectors extension where the priors get flattened somehow for q, but when q_stop is passed, something is different.

To fix this, we can sample samples_stop outside the gradient calculation. That is, do samples_stop = rand(rng, q_stop, obj.n_samples) and then pass samples_stop to aux, which probably makes more sense anyways.
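
A rough sketch of that pattern (with hypothetical stand-in names, not the actual ScoreGradELBO code): the samples are drawn before the AD call and reach the objective through aux, so neither rand nor q_stop appears inside the differentiated function.

using Distributions, LinearAlgebra, Random, Statistics

rng = Random.default_rng()
q_stop = MvNormal(zeros(3), Diagonal(ones(3)))  # stand-in for the "stopped" variational distribution
n_samples = 10

samples_stop = rand(rng, q_stop, n_samples)     # sampled outside the AD path
aux = (samples = samples_stop,)

# Only params is traced by the AD backend; aux is treated as a constant context.
function ad_forward(params, aux)
    q = MvNormal(params, Diagonal(ones(3)))                   # stand-in for restructure(params)
    return mean(logpdf(q, z) for z in eachcol(aux.samples))   # toy objective, not the real surrogate
end

ad_forward(zeros(3), aux)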

@Red-Portal (Member, Author)

@arnauqb I think that's how it is done in this PR now. However, that error is quite peculiar, and it appears that there might be something wrong with Bijectors. If you have time, could you try to come up with a MWE and take it to Bijectors?

@arnauqb (Contributor)

arnauqb commented Oct 9, 2024

The problem may be that I am using this: https://github.com/TuringLang/Turing.jl/blob/40a0d84b76e8e262e32618f83e6b895b34177d95/src/variational/advi.jl#L23
to do the automatic transformation:

using AdvancedVI
using ADTypes
using DynamicPPL
using DistributionsAD
using Distributions
using Bijectors
using Optimisers
using LinearAlgebra
using Zygote

function wrap_in_vec_reshape(f, in_size)
    vec_in_length = prod(in_size)
    reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
    out_size = Bijectors.output_size(f, in_size)
    vec_out_length = prod(out_size)
    reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
    return reshape_outer ∘ f ∘ reshape_inner
end

function Bijectors.bijector(
        model::DynamicPPL.Model, ::Val{sym2ranges} = Val(false);
        varinfo = DynamicPPL.VarInfo(model)
) where {sym2ranges}
    num_params = sum([size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata)])

    dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...)

    num_ranges = sum([length(varinfo.metadata[sym].ranges)
                      for sym in keys(varinfo.metadata)])
    ranges = Vector{UnitRange{Int}}(undef, num_ranges)
    idx = 0
    range_idx = 1

    # ranges might be discontinuous => values are vectors of ranges rather than just ranges
    sym_lookup = Dict{Symbol, Vector{UnitRange{Int}}}()
    for sym in keys(varinfo.metadata)
        sym_lookup[sym] = Vector{UnitRange{Int}}()
        for r in varinfo.metadata[sym].ranges
            ranges[range_idx] = idx .+ r
            push!(sym_lookup[sym], ranges[range_idx])
            range_idx += 1
        end

        idx += varinfo.metadata[sym].ranges[end][end]
    end

    bs = map(tuple(dists...)) do d
        b = Bijectors.bijector(d)
        if d isa Distributions.UnivariateDistribution
            b
        else
            wrap_in_vec_reshape(b, size(d))
        end
    end

    if sym2ranges
        return (
            Bijectors.Stacked(bs, ranges),
            (; collect(zip(keys(sym_lookup), values(sym_lookup)))...)
        )
    else
        return Bijectors.Stacked(bs, ranges)
    end
end
##

function double_normal()
    return MvNormal([2.0, 3.0, 4.0], Diagonal(ones(3)))
end

@model function normal_model(data)
    p1 ~ filldist(Normal(0.0, 1.0), 2)
    p2 ~ Normal(0.0, 1.0)
    ps = vcat(p1, p2)
    for i in 1:size(data, 2)
        data[:, i] ~ MvNormal(ps, Diagonal(ones(3)))
    end
end

data = rand(double_normal(), 100)
model = normal_model(data)

##

d = 3
μ = zeros(d)
L = Diagonal(ones(d));
q = AdvancedVI.MeanFieldGaussian(μ, L)
optimizer = Optimisers.Adam(1e-3)

bijector_transf = inverse(bijector(model))
q_transformed = transformed(q, bijector_transf)
ℓπ = DynamicPPL.LogDensityFunction(model)
elbo = AdvancedVI.ScoreGradELBO(10, entropy = StickingTheLandingEntropy()) # this doesn't work
#elbo = AdvancedVI.RepGradELBO(10, entropy = StickingTheLandingEntropy()) # This works

q, _, stats, _ = AdvancedVI.optimize(
	ℓπ,
	elbo,
	q_transformed,
	10;
	adtype = AutoZygote(),
	optimizer = optimizer,
)

and stacktrace:

ERROR: output length mismatch (3 != 2)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] rrule(::typeof(error), ::String)
    @ ChainRules ~/.julia/packages/ChainRules/vdf7M/src/rulesets/Base/nondiff.jl:155
  [3] rrule(::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::String)
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6Pucz/src/rules.jl:138
  [4] chain_rrule
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/chainrules.jl:224 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0 [inlined]
  [6] _pullback(ctx::Zygote.Context{false}, f::typeof(error), args::String)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:87
  [7] transform
    @ ~/.julia/packages/Bijectors/I8eRc/src/bijectors/stacked.jl:162 [inlined]
  [8] _pullback(::Zygote.Context{…}, ::typeof(transform), ::Stacked{…}, ::SubArray{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [9] Transform
    @ ~/.julia/packages/Bijectors/I8eRc/src/interface.jl:82 [inlined]
 [10] #84
    @ ~/.julia/packages/Bijectors/I8eRc/src/transformed_distribution.jl:221 [inlined]
 [11] #665
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/array.jl:188 [inlined]
 [12] iterate
    @ ./generator.jl:47 [inlined]
 [13] _collect(c::Base.OneTo{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:854
 [14] collect_similar
    @ ./array.jl:763 [inlined]
 [15] map
    @ ./abstractarray.jl:3285 [inlined]
 [16] ∇map(cx::Zygote.Context{…}, f::Bijectors.var"#84#85"{…}, args::Base.OneTo{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/lib/array.jl:188
 [17] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/array.jl:214 [inlined]
 [18] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [19] rand
    @ ~/.julia/packages/Bijectors/I8eRc/src/transformed_distribution.jl:218 [inlined]
 [20] _pullback(::Zygote.Context{…}, ::typeof(rand), ::Random.TaskLocalRNG, ::MultivariateTransformed{…}, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [21] estimate_scoregradelbo_ad_forward
    @ ~/code/AdvancedVI.jl/src/objectives/elbo/scoregradelbo.jl:102 [inlined]
 [22] _pullback(::Zygote.Context{…}, ::typeof(AdvancedVI.estimate_scoregradelbo_ad_forward), ::Vector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [23] pullback(::Function, ::Zygote.Context{false}, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:90
 [24] pullback(::Function, ::Vector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:88
 [25] withgradient(::Function, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:205
 [26] value_and_gradient
    @ ~/.julia/packages/DifferentiationInterface/qcr6f/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:92 [inlined]
 [27] value_and_gradient!(f::Function, grad::Vector{…}, prep::DifferentiationInterface.NoGradientPrep, backend::AutoZygote, x::Vector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceZygoteExt ~/.julia/packages/DifferentiationInterface/qcr6f/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:105
 [28] value_and_gradient!
    @ ~/.julia/packages/DifferentiationInterface/qcr6f/src/fallbacks/no_prep.jl:67 [inlined]
 [29] value_and_gradient!(ad::AutoZygote, f::Function, x::Vector{…}, aux::@NamedTuple{…}, out::DiffResults.MutableDiffResult{…})
    @ AdvancedVI ~/code/AdvancedVI.jl/src/AdvancedVI.jl:46
 [30] estimate_gradient!(rng::Random.TaskLocalRNG, obj::ScoreGradELBO{…}, adtype::AutoZygote, out::DiffResults.MutableDiffResult{…}, prob::DynamicPPL.LogDensityFunction{…}, params::Vector{…}, restructure::Optimisers.Restructure{…}, state::Nothing)
    @ AdvancedVI ~/code/AdvancedVI.jl/src/objectives/elbo/scoregradelbo.jl:134
 [31] optimize(::Random.TaskLocalRNG, ::DynamicPPL.LogDensityFunction{…}, ::ScoreGradELBO{…}, ::MultivariateTransformed{…}, ::Int64; adtype::AutoZygote, optimizer::Adam, averager::NoAveraging, show_progress::Bool, state_init::@NamedTuple{}, callback::Nothing, prog::ProgressMeter.Progress)
    @ AdvancedVI ~/code/AdvancedVI.jl/src/optimize.jl:76
 [32] optimize
    @ ~/code/AdvancedVI.jl/src/optimize.jl:50 [inlined]
 [33] #optimize#26
    @ ~/code/AdvancedVI.jl/src/optimize.jl:127 [inlined]
 [34] top-level scope
    @ ~/code/AdvancedVI.jl/score_bug.jl:97

@Red-Portal (Member, Author)

Red-Portal commented Oct 14, 2024

Hi @arnauqb, sorry for the late reply. Seems like this happens only on Zygote, right?

Edit: Aha! This is because Zygote attempts to differentiate through rand. This should be addressed by this PR, since it doesn't call rand in the AD path anymore.

@yebai (Member)

yebai commented Oct 23, 2024

Everything is passing other than the Enzyme tests.

@arnauqb (Contributor)

arnauqb commented Oct 28, 2024

@Red-Portal I think this line

https://github.com/TuringLang/AdvancedVI.jl/blob/45b37c111ce05f2fb01f195b8f58237bdb3a66c5/src/objectives/elbo/scoregradelbo.jl#L88C5-L88C85

should be moved outside the gradient calculation (and passed through aux). Otherwise, the gradient estimator is differentiating through the "simulator". If the simulator is not differentiable, then this is just the standard score estimator, but if it is, I think it would also be using pathwise gradients.

@Red-Portal (Member, Author)

Red-Portal commented Oct 28, 2024

@arnauqb

Hi! Hmm, it wouldn't be using the path gradient here since samples is calculated outside. My only concern would be the case where Zygote tries to differentiate through logpi anyways and throws an error if it can't, just like it did with rand(q). Did you see this happen?

@arnauqb (Contributor)

arnauqb commented Oct 28, 2024

Something similar. I have a custom Distribution for which I implement a logpdf that is calculated by sampling from the distribution and measuring a loss with respect to the value passed (think of it as doing generalized variational inference, if you are familiar with that). I have a custom Zygote rule for differentiating the rand of the distribution, and so when calling logdensity it goes through there.

Now you are right that the samples don't carry a gradient, so ultimately this won't affect the results, I think, but it may be a performance boost to take it out.
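
As an aside, one generic way to keep Zygote from tracing sampling that happens inside a log-density is ChainRulesCore.@ignore_derivatives. The following is a hypothetical toy sketch of that idea, not how the custom distribution described above or AdvancedVI is actually written:

using ChainRulesCore, Distributions

# Toy distribution whose "logpdf" compares x against a fresh draw from an internal simulator.
struct SimulatorDist{D<:Distribution} <: ContinuousMultivariateDistribution
    base::D
end
Base.length(d::SimulatorDist) = length(d.base)

function Distributions._logpdf(d::SimulatorDist, x::AbstractVector{<:Real})
    # Gradient propagation through the internal rand is stopped here, so Zygote only
    # differentiates the loss below with respect to x.
    sim = ChainRulesCore.@ignore_derivatives rand(d.base)
    return -sum(abs2, x .- sim)
end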

@Red-Portal (Member, Author)

I see. Uhghh it is annoying that Zygote can't be smarter. Anyways thanks for mentioning this, I'll try to improve it.

@Red-Portal (Member, Author)

@yebai This one is also ready to go

@Red-Portal Red-Portal requested a review from yebai December 4, 2024 02:51
@yebai yebai requested a review from mhauru December 4, 2024 10:38
Member

@mhauru mhauru left a comment


I don't understand the theory here, so just commenting on the software engineering. A few questions about tests, and requesting some more documentation.

src/objectives/elbo/scoregradelbo.jl (resolved)
src/objectives/elbo/scoregradelbo.jl (resolved)
src/objectives/elbo/scoregradelbo.jl (outdated, resolved)
test/inference/repgradelbo_distributionsad.jl (outdated, resolved)
test/inference/repgradelbo_locationscale.jl (outdated, resolved)
test/inference/scoregradelbo_locationscale.jl (outdated, resolved)
test/inference/scoregradelbo_locationscale_bijectors.jl (outdated, resolved)
test/interface/repgradelbo.jl (resolved)
Comment on lines +2 to +11
AD_scoregradelbo_interface = if TEST_GROUP == "Enzyme"
[AutoEnzyme()]
else
[
AutoForwardDiff(),
AutoReverseDiff(),
AutoZygote(),
AutoMooncake(; config=Mooncake.Config()),
]
end
Member


If this gets used a lot, could it be in a test utils module? Orthogonal to this PR though, can be done later.

Member Author


Yes I am planning to iron this out in a separate PR

test/interface/scoregradelbo.jl (resolved)
@Red-Portal Red-Portal requested a review from mhauru December 4, 2024 20:55
@Red-Portal (Member, Author)

@mhauru Oh I missed a minor bug. Let me fix this tomorrow and I'll ping you again when I'm done!

@Red-Portal (Member, Author)

@mhauru Yeah seems ready to go now!

@mhauru (Member)

mhauru commented Dec 6, 2024

Thanks @Red-Portal!

> @mhauru Oh I missed a minor bug. Let me fix this tomorrow and I'll ping you again when I'm done!

Could something be added to the test suite that would have caught this bug?

@Red-Portal (Member, Author)

Red-Portal commented Dec 6, 2024

Hi @mhauru! Nah, no need. It was a stupid typo that failed all the tests.

Member

@mhauru mhauru left a comment


I'm happy on the code side. Thanks @Red-Portal!

@Red-Portal Red-Portal merged commit 1dbf2ac into master Dec 6, 2024
12 of 17 checks passed
@Red-Portal Red-Portal deleted the tidy_scoregradelbo branch December 6, 2024 14:59
@arnauqb (Contributor)

arnauqb commented Dec 6, 2024

Thank you @Red-Portal, this is very useful!
